{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyze Classification Results\n",
    "\n",
    "This notebook is meant to be run after evaluating a trained classifier model. Copy this notebook to the specific directory where the classifier model logs were saved:\n",
    "\n",
    "```\n",
    "CameraTraps/\n",
    "    classification/\n",
    "        BASE_LOGDIR/\n",
    "            classification_ds.csv\n",
    "            label_index.json\n",
    "            LOGDIR/\n",
    "                analyze_classification_results.ipynb  # COPY THIS NOTEBOOK TO HERE\n",
    "\n",
    "                # files created by train_classifier.py\n",
    "                ckpt_XX.pt\n",
    "                events.out.tfevents...\n",
    "                params.json\n",
    "\n",
    "                # files created by evaluate_classifier.py\n",
    "                confusion_matrices.npz\n",
    "                label_stats.csv\n",
    "                outputs_{split}.csv.gz\n",
    "                overall_metrics.csv\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "!pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from typing import Optional, Sequence\n",
    "\n",
    "from IPython.display import Image\n",
    "import matplotlib.figure\n",
    "import matplotlib.image as mpimg\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn.metrics\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from visualization import plot_utils\n",
    "\n",
    "\n",
    "SPLITS = ['train', 'val', 'test']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Confusion Matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# label_index.json is keyed by the stringified class index (see the str(i) lookup)\n",
    "with open('../label_index.json', 'r') as f:\n",
    "    idx_to_label = json.load(f)\n",
    "label_names = [idx_to_label[str(i)] for i in range(len(idx_to_label))]\n",
    "\n",
    "# in the case of analyzing MegaClassifier's outputs,\n",
    "# we may need to add an extra 'other' category\n",
    "# label_names_no_other = list(label_names)\n",
    "# label_names.append('other')\n",
    "\n",
    "# np.load() on a .npz archive returns an NpzFile that keeps the file open;\n",
    "# use it as a context manager so the handle is released once plotting is done\n",
    "with np.load('confusion_matrices.npz') as cms:\n",
    "    for split in SPLITS:\n",
    "        if split not in cms:\n",
    "            print(f'Split {split} not found in confusion matrices npz file.')\n",
    "            continue\n",
    "        print(split)\n",
    "        fig = plot_utils.plot_confusion_matrix(\n",
    "            cms[split], classes=label_names, normalize=True, fmt='{:.1f}')\n",
    "        fig.set_facecolor('white')\n",
    "        display(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the gzipped outputs CSV of every split that has one, and store the\n",
    "# top-scoring label of each row in a new 'pred' column.\n",
    "output_df = {}\n",
    "for split in SPLITS:\n",
    "    outputs_path = f'outputs_{split}.csv.gz'\n",
    "    if not os.path.exists(outputs_path):\n",
    "        continue\n",
    "    split_df = pd.read_csv(outputs_path)\n",
    "    split_df['pred'] = split_df[label_names].idxmax(axis=1)\n",
    "    output_df[split] = split_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-label statistics, indexed by (split, label)\n",
    "label_stats = pd.read_csv('label_stats.csv').set_index(['split', 'label'])\n",
    "for split in output_df:\n",
    "    # add a 'mean' row holding the column-wise average over this split's rows\n",
    "    split_mean = label_stats.loc[split, :].mean()\n",
    "    label_stats.loc[(split, 'mean'), :] = split_mean\n",
    "\n",
    "    # for MegaClassifier, we might want to take the mean excluding the \"other\" category\n",
    "    # label_stats.loc[(split, 'mean (excluding other)'), :] = (\n",
    "    #     label_stats.loc[(split, label_names_no_other), :].mean()\n",
    "    # )\n",
    "display(label_stats.unstack('split'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-split statistics next to the number of test examples per true label\n",
    "test_counts = output_df['test'].groupby('label').size()\n",
    "test_perf = label_stats.loc['test', :].copy()\n",
    "test_perf['count'] = test_counts\n",
    "display(test_perf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot classifier calibration on test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ncols = 3\n",
    "nrows = int(np.ceil(len(label_names) / ncols))\n",
    "fig = matplotlib.figure.Figure(figsize=(ncols * 5, nrows * 5), tight_layout=True, facecolor='white')\n",
    "axs = fig.subplots(nrows, ncols, squeeze=False)\n",
    "flat_axs = axs.ravel()  # row-major order, one panel per label\n",
    "test_outputs = output_df['test']\n",
    "\n",
    "for ax, label_name in zip(flat_axs, label_names):\n",
    "    # only the examples whose top prediction is this label\n",
    "    predicted_as_label = test_outputs.loc[test_outputs['pred'] == label_name]\n",
    "    plot_utils.plot_calibration_curve(\n",
    "        true_scores=(predicted_as_label['label'] == predicted_as_label['pred']),\n",
    "        pred_scores=predicted_as_label[label_name],\n",
    "        num_bins=20, name=label_name, ax=ax)\n",
    "    ax.legend()\n",
    "\n",
    "# hide unused axes\n",
    "for ax in flat_axs[len(label_names):]:\n",
    "    ax.set_axis_off()\n",
    "\n",
    "display(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot confusion matrix for test set if we set prediction confidence threshold at 0.99"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only predictions with at least this confidence are kept\n",
    "CONF_THRESHOLD = 0.99\n",
    "\n",
    "# NOTE: test_df is the same object as output_df['test'], so the new\n",
    "# 'pred_conf' column is added to it in place\n",
    "test_df = output_df['test']\n",
    "\n",
    "# 'pred' is the idxmax over the label score columns, so the confidence of the\n",
    "# predicted label is the row-wise max over those same columns. This replaces\n",
    "# DataFrame.lookup(), which was deprecated in pandas 1.2 and removed in 2.0,\n",
    "# and does not rely on the frame having a default RangeIndex.\n",
    "test_df['pred_conf'] = test_df[label_names].max(axis=1)\n",
    "\n",
    "mask = test_df['pred_conf'] >= CONF_THRESHOLD\n",
    "cm = sklearn.metrics.confusion_matrix(\n",
    "    y_true=test_df.loc[mask, 'label'],\n",
    "    y_pred=test_df.loc[mask, 'pred'],\n",
    "    labels=label_names)\n",
    "\n",
    "fig = plot_utils.plot_confusion_matrix(cm, classes=label_names, normalize=True, fmt='{:.1f}')\n",
    "fig.set_facecolor('white')\n",
    "display(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_img_grid(paths: Sequence[str], ncols: int, size: float) -> matplotlib.figure.Figure:\n",
    "    \"\"\"Plot a grid of square images.\n",
    "\n",
    "    Images fill the grid in row-major order; every axis (including grid\n",
    "    cells left without an image) is turned off.\n",
    "\n",
    "    Args:\n",
    "        paths: sequence of str, paths to image crops\n",
    "        ncols: int, number of columns for output figure\n",
    "        size: float, nominal side length of each grid cell in pixels at\n",
    "            DPI dots per inch; it is divided by DPI to get the figure size\n",
    "            in inches\n",
    "\n",
    "    Returns: matplotlib Figure\n",
    "    \"\"\"\n",
    "    DPI = 113\n",
    "    # always allocate at least one row, so that an empty `paths` yields a\n",
    "    # blank figure instead of an error from fig.subplots(0, ncols)\n",
    "    nrows = max(1, int(np.ceil(len(paths) / ncols)))\n",
    "    fig = matplotlib.figure.Figure(figsize=(ncols * size / DPI, nrows * size / DPI))\n",
    "    axs = fig.subplots(nrows, ncols, squeeze=False)\n",
    "    # wrap `paths` itself (not the enumerate object, which has no length)\n",
    "    # so that tqdm knows the total and can show real progress\n",
    "    for i, path in enumerate(tqdm(paths)):\n",
    "        r, c = i // ncols, i % ncols\n",
    "        axs[r, c].imshow(mpimg.imread(path))\n",
    "    for ax in axs.flat:\n",
    "        ax.set_axis_off()\n",
    "        ax.set_aspect('equal')\n",
    "    fig.subplots_adjust(wspace=0, hspace=0)\n",
    "    return fig\n",
    "\n",
    "def plot_images_groupby_pred(df: pd.DataFrame, ncols: int, size: float,\n",
    "                             cropped_images_dir: str = '/ssd/crops_sq',\n",
    "                             count: Optional[int] = None,\n",
    "                             random_state: Optional[int] = None) -> None:\n",
    "    \"\"\"Creates one figure for each prediction.\n",
    "\n",
    "    Groups the rows of df by their 'pred' column, prints each predicted\n",
    "    label, and displays a grid of the corresponding image crops.\n",
    "\n",
    "    Args:\n",
    "        df: pd.DataFrame, classifier output dataframe with 'pred' and 'path'\n",
    "            columns, all examples belong to the same label\n",
    "        ncols: int, number of columns for output figure\n",
    "        size: float, nominal side length of each grid cell in pixels,\n",
    "            passed through to plot_img_grid()\n",
    "        cropped_images_dir: str, path to cropped images; each value of the\n",
    "            'path' column is joined onto this directory\n",
    "        count: optional int, limit on number of images to show for each\n",
    "            prediction; larger groups are randomly sampled down to this size\n",
    "        random_state: optional int, seed for the random sampling so that the\n",
    "            displayed subset is reproducible; None keeps it unseeded\n",
    "    \"\"\"\n",
    "    for name, pred_df in df.groupby('pred'):\n",
    "        print(name)\n",
    "        if count is not None and len(pred_df) > count:\n",
    "            print(f'Original count: {len(pred_df)}. Sampling {count}')\n",
    "            pred_df = pred_df.sample(count, random_state=random_state)\n",
    "        image_crop_paths = pred_df['path'].map(lambda x: os.path.join(cropped_images_dir, x))\n",
    "        fig = plot_img_grid(image_crop_paths, ncols=ncols, size=size)\n",
    "        display(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot all pronghorn images from test set, and show top predicted label\n",
    "label = 'pronghorn'\n",
    "df = output_df['test']\n",
    "has_label = df['label'] == label\n",
    "plot_images_groupby_pred(df.loc[has_label], ncols=8, size=224, count=40)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cameratraps-classifier] *",
   "language": "python",
   "name": "conda-env-cameratraps-classifier-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
